{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 梯度迭代树回归（GBDT）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml import Pipeline\n",
    "from pyspark.ml.evaluation import RegressionEvaluator\n",
    "from pyspark.ml.feature import VectorIndexer\n",
    "from pyspark.ml.regression import GBTRegressor\n",
    "from pyspark.sql import SparkSession\n",
    "\n",
    "# Settings a reader may want to tune.\n",
    "DATA_PATH = \"data/mllib/sample_libsvm_data.txt\"\n",
    "MAX_CATEGORIES = 4  # features with more distinct values are treated as continuous\n",
    "SEED = 42  # seeds the train/test split so reruns use the same partitioning\n",
    "\n",
    "# Reuse the session the environment already provides, or create one, so the\n",
    "# cell also runs on a fresh plain-Python kernel where `spark` is not predefined.\n",
    "spark = SparkSession.builder.getOrCreate()\n",
    "\n",
    "# Load the LIBSVM-formatted data file into a DataFrame.\n",
    "data = spark.read.format(\"libsvm\").load(DATA_PATH)\n",
    "\n",
    "# Automatically identify categorical features and index them.\n",
    "# The indexer is fitted on the full dataset, i.e. before the train/test split.\n",
    "featureIndexer = VectorIndexer(\n",
    "    inputCol=\"features\", outputCol=\"indexedFeatures\", maxCategories=MAX_CATEGORIES\n",
    ").fit(data)\n",
    "\n",
    "# Split the data into training and test sets (30% held out for testing).\n",
    "(trainingData, testData) = data.randomSplit([0.7, 0.3], seed=SEED)\n",
    "\n",
    "# Configure the GBT regressor; nothing is trained until pipeline.fit() below.\n",
    "gbt = GBTRegressor(featuresCol=\"indexedFeatures\", maxIter=10)\n",
    "\n",
    "# Chain the fitted indexer and the GBT estimator in a Pipeline.\n",
    "pipeline = Pipeline(stages=[featureIndexer, gbt])\n",
    "\n",
    "# Fit the pipeline on the training split; this is where the GBT model is trained.\n",
    "model = pipeline.fit(trainingData)\n",
    "\n",
    "# Make predictions on the held-out test split.\n",
    "predictions = model.transform(testData)\n",
    "\n",
    "# Display a few example rows.\n",
    "predictions.select(\"prediction\", \"label\", \"features\").show(5)\n",
    "\n",
    "# Compare predictions with the true labels and compute the test RMSE.\n",
    "evaluator = RegressionEvaluator(\n",
    "    labelCol=\"label\", predictionCol=\"prediction\", metricName=\"rmse\")\n",
    "rmse = evaluator.evaluate(predictions)\n",
    "print(\"Root Mean Squared Error (RMSE) on test data = %g\" % rmse)\n",
    "\n",
    "# The fitted GBT model is the second pipeline stage (index 1, after the indexer).\n",
    "# Printing it shows a summary only.\n",
    "gbtModel = model.stages[1]\n",
    "print(gbtModel)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
